{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmarking Thinc layers with a custom `benchmark` layer\n",
    "\n",
    "This notebook shows how to write a `benchmark` layer that can wrap any layer(s) in your network and that **logs the execution times** of the initialization, forward pass and backward pass. The benchmark layer can also be mapped to an operator like `@` to make it easy to add debugging to your network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install \"thinc>=8.0.0a0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To log the results, we first set up a custom logger using Python's `logging` module. You could also just print the stats instead, but using `logging` is cleaner: it lets other users modify the logger's behavior more easily, and it lets you separate the logs from other output and write them to a file (e.g. if you're benchmarking several layers during training). The following logging config outputs the date and time, the name of the logger and the logged results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "logger = logging.getLogger(\"thinc:benchmark\")\n",
    "if not logger.hasHandlers():  # prevent duplicate handlers when the cell is re-run in Jupyter\n",
    "    formatter = logging.Formatter('%(asctime)s %(name)s  %(message)s', datefmt=\"%Y-%m-%d %H:%M:%S\")\n",
    "    handler = logging.StreamHandler()\n",
    "    handler.setFormatter(formatter)\n",
    "    logger.addHandler(handler)\n",
    "    logger.setLevel(logging.DEBUG)"
   ]
  },
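  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you'd rather collect the benchmark logs in a file, one option is to attach a `logging.FileHandler` in place of the `StreamHandler`. The following is only a minimal sketch: the logger name `thinc:benchmark:file`, the temporary file path and the sample message are placeholders chosen for illustration and aren't used elsewhere in this notebook.\n",
    "\n",
    "```python\n",
    "import logging\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "# Hypothetical logger name, used only in this sketch\n",
    "file_logger = logging.getLogger('thinc:benchmark:file')\n",
    "file_logger.setLevel(logging.DEBUG)\n",
    "file_logger.propagate = False  # don't pass records on to ancestor loggers\n",
    "\n",
    "# Placeholder path in a fresh temporary directory\n",
    "log_path = os.path.join(tempfile.mkdtemp(), 'benchmark.log')\n",
    "file_handler = logging.FileHandler(log_path)\n",
    "file_handler.setFormatter(\n",
    "    logging.Formatter('%(asctime)s %(name)s  %(message)s', datefmt='%Y-%m-%d %H:%M:%S')\n",
    ")\n",
    "file_logger.addHandler(file_handler)\n",
    "\n",
    "# Sample message in the same layout the time logger uses\n",
    "file_logger.debug('linear       | FORWARD  | 0.000424')\n",
    "file_handler.flush()\n",
    "\n",
    "# Read the file back to check what was written\n",
    "with open(log_path) as f:\n",
    "    print(f.read())\n",
    "```\n",
    "\n",
    "Using a separate logger keeps the sketch from changing the `thinc:benchmark` logger configured above. In practice, you could also add the file handler to that logger directly."
   ]
  },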
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's a minimalistic time logger that can be initialized with the name of a given layer, and can track several events (e.g. `\"forward\"` and `\"backward\"`). When the `TimeLogger.end` method is called, the output is formatted nicely and the elapsed time is logged with the logger name and colored label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from timeit import default_timer\n",
    "from wasabi import color\n",
    "\n",
    "\n",
    "class TimeLogger:\n",
    "    def __init__(self, name):\n",
    "        self.colors = {\"forward\": \"green\", \"backward\": \"blue\"}\n",
    "        self.name = name\n",
    "        self.timers = {}\n",
    "        \n",
    "    def start(self, name):\n",
    "        self.timers[name] = default_timer()\n",
    "        \n",
    "    def end(self, name):\n",
    "        result = default_timer() - self.timers[name]\n",
    "        label = f\"{name.upper():<8}\"\n",
    "        label = color(label, self.colors.get(name), bold=True)\n",
    "        logger.debug(f\"{self.name:<12} | {label} | {result:.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `benchmark` layer now has to wrap the forward pass, backward pass and initialization of the layer it wraps and log the execution times. It then returns a Thinc model instance with the custom `forward` function and a custom `init` function. We'll also allow setting a custom `name` to make it easier to tell multiple wrapped benchmark layers apart."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from thinc.api import Model\n",
    "\n",
    "\n",
    "def benchmark(layer, name=None):\n",
    "    name = name if name is not None else layer.name\n",
    "    t = TimeLogger(name)\n",
    "\n",
    "    def init(model, X=None, Y=None):\n",
    "        # Time the initialization of the wrapped layer\n",
    "        t.start(\"init\")\n",
    "        result = layer.initialize(X, Y)\n",
    "        t.end(\"init\")\n",
    "        return result\n",
    "\n",
    "    def forward(model, X, is_train):\n",
    "        # Time the forward pass of the wrapped layer\n",
    "        t.start(\"forward\")\n",
    "        layer_Y, layer_callback = layer(X, is_train=is_train)\n",
    "        t.end(\"forward\")\n",
    "\n",
    "        def backprop(dY):\n",
    "            # Time the backward pass of the wrapped layer\n",
    "            t.start(\"backward\")\n",
    "            result = layer_callback(dY)\n",
    "            t.end(\"backward\")\n",
    "            return result\n",
    "\n",
    "        return layer_Y, backprop\n",
    "\n",
    "    # Register the wrapped layer as a child so that its parameters are\n",
    "    # reachable via model.walk(), e.g. when calling model.finish_update\n",
    "    return Model(f\"benchmark:{layer.name}\", forward, init=init, layers=[layer])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Usage examples\n",
    "\n",
    "### Using the `benchmark` layer as a function\n",
    "\n",
    "We can now wrap one or more layers (including nested layers) with the `benchmark` function. This is the original model:\n",
    "\n",
    "```python\n",
    "model = chain(Linear(1), Linear(1))\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-01-22 00:09:50 thinc:benchmark  linear       | \u001b[1mINIT    \u001b[0m | 0.000236\n",
      "2020-01-22 00:09:50 thinc:benchmark  linear       | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000424\n",
      "2020-01-22 00:09:50 thinc:benchmark  outer        | \u001b[1mINIT    \u001b[0m | 0.004570\n",
      "2020-01-22 00:09:50 thinc:benchmark  linear       | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000116\n",
      "2020-01-22 00:09:50 thinc:benchmark  outer        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.016403\n",
      "2020-01-22 00:09:50 thinc:benchmark  linear       | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000239\n",
      "2020-01-22 00:09:50 thinc:benchmark  outer        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.003203\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "from thinc.api import chain, Linear\n",
    "\n",
    "X = numpy.zeros((1, 2), dtype=\"f\")\n",
    "\n",
    "model = benchmark(chain(benchmark(Linear(1)), Linear(1)), name=\"outer\")\n",
    "model.initialize(X=X)\n",
    "Y, backprop = model(X, is_train=False)\n",
    "dX = backprop(Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using the `benchmark` layer as an operator\n",
    "\n",
    "Alternatively, we can use `Model.define_operators` to map `benchmark` to an operator like `@`. The left argument of the operator is the first argument passed into the function (the layer) and the right argument is the second argument (the name). Because `@` binds more tightly than `>>` in Python, `Linear(1) @ \"first\" >> Linear(1)` wraps only the first `Linear` layer before chaining. The following example wraps the whole network (two chained `Linear` layers) in a benchmark layer named `\"outer\"`, and the first `Linear` layer in a benchmark layer named `\"first\"`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-01-22 00:09:55 thinc:benchmark  first        | \u001b[1mINIT    \u001b[0m | 0.000315\n",
      "2020-01-22 00:09:55 thinc:benchmark  first        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000106\n",
      "2020-01-22 00:09:55 thinc:benchmark  outer        | \u001b[1mINIT    \u001b[0m | 0.009043\n",
      "2020-01-22 00:09:55 thinc:benchmark  first        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000104\n",
      "2020-01-22 00:09:55 thinc:benchmark  outer        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.003850\n",
      "2020-01-22 00:09:55 thinc:benchmark  first        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000133\n",
      "2020-01-22 00:09:55 thinc:benchmark  outer        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.010088\n"
     ]
    }
   ],
   "source": [
    "from thinc.api import Model\n",
    "\n",
    "with Model.define_operators({\">>\": chain, \"@\": benchmark}):\n",
    "    model = (Linear(1) @ \"first\" >> Linear(1)) @ \"outer\"\n",
    "    \n",
    "model.initialize(X=X)\n",
    "Y, backprop = model(X, is_train=True)\n",
    "dX = backprop(Y)"
   ]
  },
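  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how Python groups an operator expression like the one above, here's a small sketch that doesn't use Thinc at all. The `Node` class is a hypothetical stand-in for a layer: its `@` and `>>` methods only record which call each operator would correspond to, assuming `@` is mapped to `benchmark` and `>>` to `chain`.\n",
    "\n",
    "```python\n",
    "class Node:\n",
    "    # Hypothetical stand-in for a layer, only records the operator grouping\n",
    "    def __init__(self, label):\n",
    "        self.label = label\n",
    "\n",
    "    def __matmul__(self, name):\n",
    "        # Stands in for benchmark(layer, name)\n",
    "        return Node(f'benchmark({self.label}, {name!r})')\n",
    "\n",
    "    def __rshift__(self, other):\n",
    "        # Stands in for chain(left, right)\n",
    "        return Node(f'chain({self.label}, {other.label})')\n",
    "\n",
    "\n",
    "expr = (Node('Linear') @ 'first' >> Node('Linear')) @ 'outer'\n",
    "print(expr.label)\n",
    "# benchmark(chain(benchmark(Linear, 'first'), Linear), 'outer')\n",
    "```\n",
    "\n",
    "Since `@` has a higher precedence than `>>`, the name is attached to the layer directly to its left before the layers are chained. The parentheses are only needed to wrap the whole chain."
   ]
  },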
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using the `benchmark` layer during training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1mINIT    \u001b[0m | 0.001160\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000258\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1mINIT    \u001b[0m | 0.000322\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000193\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000360\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000551\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000665\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.001340\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000253\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000772\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000252\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000265\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000196\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000274\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000170\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000223\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000364\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.001056\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000542\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000320\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000119\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.001592\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000245\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000169\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000166\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.014438\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000822\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000175\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000062\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000636\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000184\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000149\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000090\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000233\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000171\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000162\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000087\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000285\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000176\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000160\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000088\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.001057\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000239\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;2mFORWARD \u001b[0m | 0.000166\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu2        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000090\n",
      "2020-01-22 00:10:04 thinc:benchmark  relu1        | \u001b[1;38;5;4mBACKWARD\u001b[0m | 0.000238\n"
     ]
    }
   ],
   "source": [
    "from thinc.api import Model, chain, Relu, Softmax, Adam\n",
    "\n",
    "n_hidden = 32\n",
    "dropout = 0.2\n",
    "\n",
    "with Model.define_operators({\">>\": chain, \"@\": benchmark}):\n",
    "    model = (\n",
    "        Relu(nO=n_hidden, dropout=dropout) @ \"relu1\"\n",
    "        >> Relu(nO=n_hidden, dropout=dropout) @ \"relu2\"\n",
    "        >> Softmax()\n",
    "    )\n",
    "\n",
    "train_X = numpy.zeros((5, 784), dtype=\"f\")\n",
    "train_Y = numpy.zeros((5, 10), dtype=\"f\")\n",
    "\n",
    "model.initialize(X=train_X[:5], Y=train_Y[:5])\n",
    "optimizer = Adam(0.001)\n",
    "for i in range(10):\n",
    "    for X, Y in model.ops.multibatch(8, train_X, train_Y, shuffle=True):\n",
    "        Yh, backprop = model.begin_update(X)\n",
    "        backprop(Yh - Y)\n",
    "        model.finish_update(optimizer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.2 64-bit ('.env': venv)",
   "language": "python",
   "name": "python37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}